absl-py
chex
clu
distrax
flax
jax==0.2.9
jaxlib==0.1.60
ml_collections
numpy
optax
tensorflow-cpu
tensorflow-datasets